# code/stage5_model.py

import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from paths import PROCESSED_DATA_DIR, RESULTS_DIR

# Read the final analysis dataset from the processed-data directory.
dataset_path = PROCESSED_DATA_DIR / 'final_analysis_dataset.parquet'
df = pd.read_parquet(dataset_path)

# Variable roles used by the models below: the outcome, the individual-level
# predictors, the macro-level predictors, and the grouping/time identifiers.
depend = 'egal_index'
indivs = ['SEX', 'AGE', 'educ_level', 'Source']
macros = ['log_gdp_ppp', 'gii_index', 'libdem_index']
group = ['C_ALPHAN', 'YEAR']
all_vars = [depend, *indivs, *macros, *group]

# Complete-case sample: keep only the model variables and drop every row
# with a missing value in any of them.  The copy lets the columns added
# later be written to an independent frame rather than a slice of df.
model_df = df.loc[:, all_vars].dropna(how='any').copy()

# transformations
# Quadratic age term: raw AGE is squared first and the result is mean-centered
# in the loop below (it is not the square of centered AGE).
model_df['AGE_sq'] = model_df['AGE'] ** 2

# Dummy-code the two categorical predictors, dropping the first level of each
# as the reference category.  dtype=float keeps the indicators numeric on
# every pandas version: since pandas 2.0 the default indicator dtype is bool,
# which patsy treats as categorical, so a formula term such as `educ_2` would
# be labelled `educ_2[T.True]` in the summaries instead of `educ_2`.
model_df = pd.get_dummies(
    model_df,
    columns=['educ_level', 'Source'],
    prefix=['educ', 'Source'],
    drop_first=True,
    dtype=float,
)

# Grand-mean center the continuous predictors (means are taken over the rows
# of model_df); each centered copy is stored next to the original as
# `<name>_c`.
to_center = ['AGE', 'AGE_sq', *macros]
for v in to_center:
    model_df[f'{v}_c'] = model_df[v] - model_df[v].mean()

# Baseline mixed model with groups given by C_ALPHAN (random intercept only,
# no re_formula is supplied).  Fixed effects: SEX as a treatment-coded factor
# with 2 passed as the reference, plus the centered age and age-squared
# terms.  reml=False requests a maximum-likelihood fit rather than REML.
f_base = " + ".join([
    "egal_index ~ C(SEX, Treatment(2))",
    "AGE_c",
    "AGE_sq_c",
])
m0 = smf.mixedlm(f_base, data=model_df, groups=model_df['C_ALPHAN'])
r0 = m0.fit(reml=False)

# Save the text rendering of the fit summary to the results directory.
base_report = r0.summary().as_text()
with open(RESULTS_DIR / 'model_base.txt', 'w') as out:
    out.write(base_report)

# Main-effects model: the baseline terms plus the education and Source dummy
# columns (presumably the ones created by get_dummies earlier in the script
# -- their exact names depend on the category labels in the data), the three
# centered macro predictors, and YEAR as a categorical fixed effect.
macro_terms = [
    "C(SEX, Treatment(2))",
    "AGE_c",
    "AGE_sq_c",
    "educ_2",
    "educ_3",
    "Source_WVS",
    "log_gdp_ppp_c",
    "gii_index_c",
    "libdem_index_c",
    "C(YEAR)",
]
f_macro = "egal_index ~ " + " + ".join(macro_terms)
m1 = smf.mixedlm(f_macro, data=model_df, groups=model_df['C_ALPHAN'])
r1 = m1.fit(reml=False)  # maximum likelihood, not REML

# Save the text rendering of the fit summary to the results directory.
macro_report = r1.summary().as_text()
with open(RESULTS_DIR / 'model_macro_main.txt', 'w') as out:
    out.write(macro_report)

# Two-way interaction model: extends the main-effects formula with the
# products of centered gii_index and (a) SEX, (b) each education dummy.
f_int2 = " + ".join([
    f_macro,
    "gii_index_c:C(SEX, Treatment(2))",
    "gii_index_c:educ_2",
    "gii_index_c:educ_3",
])
m2 = smf.mixedlm(f_int2, data=model_df, groups=model_df['C_ALPHAN'])
r2 = m2.fit(reml=False)  # maximum likelihood, not REML

# Save the text rendering of the fit summary to the results directory.
int2_report = r2.summary().as_text()
with open(RESULTS_DIR / 'model_interaction2.txt', 'w') as out:
    out.write(int2_report)

# Three-way interaction model: extends the two-way formula with the product
# of centered log GDP, centered gii_index and SEX.
# NOTE(review): only the three-way product is added here; the two-way terms
# log_gdp_ppp_c:gii_index_c and log_gdp_ppp_c:C(SEX, Treatment(2)) are not
# listed as separate terms in this formula or in f_int2 -- confirm that this
# is the intended specification.
f_int3 = f"{f_int2} + log_gdp_ppp_c:gii_index_c:C(SEX, Treatment(2))"
m3 = smf.mixedlm(f_int3, data=model_df, groups=model_df['C_ALPHAN'])
r3 = m3.fit(reml=False)  # maximum likelihood, not REML

# Save the text rendering of the fit summary to the results directory.
int3_report = r3.summary().as_text()
with open(RESULTS_DIR / 'model_interaction3.txt', 'w') as out:
    out.write(int3_report)
